# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import unittest
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C._distributed_c10d import Backend as C10dBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._mesh_layout import _MeshLayout as _Layout
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
    _get_default_group,
    _world,
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
)
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._collective_utils import (
    mesh_broadcast,
    mesh_scatter,
    unpad_tensor,
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeStore
from torch.utils._typing_utils import not_none


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()

# Detect whether this PyTorch build exposes the NCCL process group.
# ``ProcessGroupNCCL`` is a class attribute of the ``torch._C._distributed_c10d``
# extension module rather than a submodule, so it has to be fetched with
# ``from ... import``. A dotted ``import pkg.mod.Class`` statement raises
# ModuleNotFoundError for a non-module attribute, which would leave the flag
# False on every build.
try:
    from torch._C._distributed_c10d import ProcessGroupNCCL  # noqa: F401

    _NCCL_AVAILABLE = True
except ImportError:
    _NCCL_AVAILABLE = False


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
    """Write the given distributed settings into ``os.environ``.

    Args:
        addr: value stored under ``MASTER_ADDR``.
        port: value stored under ``MASTER_PORT``; may be a string or an int.
        world_size: value stored under ``WORLD_SIZE``.
        rank: value stored under ``RANK``.
        local_rank: value stored under ``LOCAL_RANK``. The default ``-1`` is a
            sentinel meaning "do not touch ``LOCAL_RANK``"; an existing value
            is left as it is in that case.
    """
    settings = {
        "MASTER_ADDR": addr,
        "MASTER_PORT": port,
        "WORLD_SIZE": world_size,
        "RANK": rank,
    }
    if local_rank != -1:
        settings["LOCAL_RANK"] = local_rank

    # os.environ only accepts strings, so stringify every value; this also
    # lets callers pass the port as an int.
    for key, value in settings.items():
        os.environ[key] = str(value)


@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
    """DeviceMesh checks executed with the gloo backend."""

    @property
    def backend(self):
        """Name of the backend reported by this test class."""
        return "gloo"

    @with_comms
    def test_device_mesh_reuse_default_group(self):
        """Compare the group of a 1-D, world-sized mesh with the default group."""
        world_mesh = init_device_mesh(self.device_type, (self.world_size,))
        group_from_mesh = world_mesh.get_group()
        world_group = _get_default_group()

        if not torch.cuda.is_available():
            # Without CUDA the mesh is expected to hand back the default group.
            self.assertEqual(group_from_mesh, world_group)
            return

        # With CUDA the two groups are expected to differ while still spanning
        # the same number of ranks.
        self.assertNotEqual(group_from_mesh, world_group)
        self.assertEqual(get_world_size(group_from_mesh), get_world_size(world_group))


class DeviceMeshSetDeviceTest(DTensorTestBase):
    """Checks how constructing a DeviceMesh interacts with the current device.

    None of these tests use ``with_comms``: each one exports the environment
    variables itself, asserts that constructing the mesh initializes the
    process group, and then inspects the current accelerator device index.
    """

    @property
    def world_size(self):
        return 4

    @skip_if_lt_x_gpu(4)
    def test_manual_set_device(self):
        """A device selected before mesh construction must stay selected."""
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())

        # Set the device on each process before DeviceMesh constructor,
        # and device to be different than the default world rank
        expected_index = (self.rank + 2) % self.world_size
        torch.accelerator.set_device_index(expected_index)
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(self.device_type, mesh_tensor)
        self.assertTrue(is_initialized())

        # check that the device is set to the correct device
        # and respect the previous set_device calls
        self.assertEqual(torch.accelerator.current_device_index(), expected_index)
        self.destroy_pg()

    @skip_if_lt_x_gpu(4)
    def test_auto_set_device_from_local_rank(self):
        """The ``LOCAL_RANK`` env var decides the device when it is set."""
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())
        # set the local rank to be different than the default world rank,
        # DeviceMesh should respect LOCAL_RANK env var if it's set
        local_rank = (self.rank + 1) % self.world_size

        _set_env_var(
            world_size=self.world_size,
            rank=self.rank,
            local_rank=local_rank,
        )
        DeviceMesh(self.device_type, mesh_tensor)
        self.assertTrue(is_initialized())

        # check that the device is set to the correct device
        # and respect the LOCAL_RANK env var
        self.assertEqual(torch.accelerator.current_device_index(), local_rank)
        self.destroy_pg()

    @skip_if_lt_x_gpu(4)
    def test_auto_set_device_from_heuristic(self):
        """With no device selected and no ``LOCAL_RANK``, a warning is issued
        and the device index ends up equal to the rank."""
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())

        _set_env_var(
            world_size=self.world_size,
            rank=self.rank,
        )
        with self.assertWarnsRegex(
            UserWarning, "It seems like you did not set/select the default device"
        ):
            DeviceMesh(self.device_type, mesh_tensor)
        self.assertTrue(is_initialized())

        # check that the device is set to the correct device
        self.assertEqual(torch.accelerator.current_device_index(), self.rank)
        self.destroy_pg()


class DeviceMeshTest(DTensorTestBase):
    """Core DeviceMesh behavior on a 4-rank world.

    Covers construction, process-group lookup, local ranks, ``from_group``
    and root-mesh tracking.
    """

    @property
    def world_size(self):
        return 4

    @skip_if_lt_x_gpu(4)
    def test_init_process_group(self):
        """Constructing a DeviceMesh initializes the default process group."""
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        DeviceMesh(self.device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg(self.rank)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_assert_invalid_mesh_tensor(self):
        """A mesh tensor moved to a device with ``.to(rank)`` is rejected."""
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            DeviceMesh(self.device_type, mesh)

    @with_comms()
    def test_2d_mesh_non_eager_init_subgroup(self):
        """Without eager init, the dim groups have no bound device id."""
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(mesh_2d.get_group(0).bound_device_id, None)
        self.assertEqual(mesh_2d.get_group(1).bound_device_id, None)

    # TODO: need to refactor the other tests in this file to test both
    # eager_init=True and eager_init=False scenarios.
    @with_comms(eager_init=True)
    def test_2d_mesh_eager_init_subgroup(self):
        """With eager init on nccl, dim groups are bound to the current device."""
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)

        # when eager init is used, the subgroup is created from nccl comm split and
        # there would be bound_device_id immediately assigned for the subgroup.
        if self.backend == "nccl":
            curr_device = torch.cuda.current_device()
            self.assertEqual(mesh_2d.get_group(0).bound_device_id.index, curr_device)
            self.assertEqual(mesh_2d.get_group(1).bound_device_id.index, curr_device)

    @with_comms()
    def test_get_group_and_get_all_groups(self):
        """Groups can be looked up by index, by name, or via a sliced submesh."""
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

        groups = mesh_2d.get_all_groups()
        self.assertEqual(len(groups), 2)
        self.assertIn(tp_mesh.get_group(), groups)
        self.assertIn(dp_mesh.get_group(), groups)

    @with_comms
    def test_get_local_rank_raises_exception(self):
        """``get_local_rank()`` without ``mesh_dim`` fails on a 2-D mesh."""
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            mesh_2d.get_local_rank()

    @with_comms
    def test_get_local_rank(self):
        """Local ranks agree across name lookup, index lookup and submeshes."""
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

        # Verify flattened mesh local rank correctness.
        flattened_mesh = mesh_2d["dp", "tp"]._flatten()
        self.assertEqual(flattened_mesh.get_local_rank(), self.rank)

    @with_comms
    def test_device_mesh_2d(self):
        """Each dim group of a 2x2 mesh contains the expected global ranks."""
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a device mesh for self.device_type
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_all_groups()

        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertLess(dim, 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            # Pick whichever of the two expected groups contains this rank.
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @with_comms
    def test_device_mesh_init_backend(self):
        """A mesh built with ``_init_backend=False`` has coordinates but no groups."""
        mesh = DeviceMesh(
            self.device_type, torch.arange(10), _init_backend=False, _rank=5
        )

        with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"):
            mesh.get_group()

        # coordinates should always been populated when init_backend is False, as whenever
        # we call init_backend we should make sure the default pg already created
        self.assertEqual(mesh.get_coordinate(), [5])

    def test_fake_pg_device_mesh(self):
        """Run a functional all-gather through a mesh on the "fake" backend."""
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        # Tear the fake group down even when an assertion fails, so that it
        # does not outlive this test.
        try:
            device_type = (
                torch.accelerator.current_accelerator().type
                if torch.accelerator.is_available()
                else "cpu"
            )
            mesh = DeviceMesh(device_type, torch.arange(self.world_size))

            local_tensor = torch.randn(2, 8)
            global_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=0, group=(mesh, 0)
            ).wait()
            self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))
        finally:
            dist.destroy_process_group()

    @with_comms
    def test_from_group_with_global_pg(self):
        """``from_group`` on a mesh's own group reproduces that mesh."""
        # Simple test: check `from_group` from a mesh pg vs. directly
        # initializing via `init_device_mesh`
        ref_global_mesh = init_device_mesh(self.device_type, (self.world_size,))
        mesh_pg = ref_global_mesh.get_group()

        def assert_matches_reference(candidate):
            self.assertEqual(ref_global_mesh, candidate)
            self.assertEqual(
                ref_global_mesh._dim_group_names, candidate._dim_group_names
            )
            self.assertEqual(
                ref_global_mesh._coordinate_on_dim, candidate._coordinate_on_dim
            )

        assert_matches_reference(DeviceMesh.from_group(mesh_pg, self.device_type))
        # Check when `mesh` is passed as well
        assert_matches_reference(
            DeviceMesh.from_group(
                mesh_pg, self.device_type, mesh=torch.arange(self.world_size)
            )
        )

    @with_comms
    def test_from_group_with_invalid_mesh(self):
        """``from_group`` rejects a mesh whose shape does not fit the group(s)."""
        global_pg = _get_default_group()
        global_pg_size = global_pg.size()
        self.assertEqual(global_pg_size, 4, msg="Test assumes global world size of 4")
        invalid_mesh = [[0, 1], [2, 3]]  # 2D mesh when we need 1D
        regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(
                global_pg,
                self.device_type,
                invalid_mesh,
                mesh_dim_names=("dim0", "dim1"),
            )

        device_mesh = init_device_mesh(self.device_type, (2, 2))
        groups = device_mesh.get_all_groups()
        invalid_mesh = (0, 1, 2, 3)  # 1D mesh when we need 2D
        regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups"
        with self.assertRaisesRegex(ValueError, regex):
            DeviceMesh.from_group(
                groups, self.device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
            )

    def test_raises_invalid_device_type(self):
        """``init_device_mesh`` rejects a device type that carries an index."""
        mesh_shape = (2, self.world_size // 2)
        with self.assertRaisesRegex(
            RuntimeError,
            "Device type with index is not supported",
        ):
            # test init_device_mesh with an invalid device type that contains a GPU index
            init_device_mesh(
                f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
            )

    @with_comms
    def test_get_root_mesh_multiple_independent_meshes(self):
        """Submeshes of independent meshes resolve to their own root mesh."""
        # regression test for issue #163330
        # when creating multiple independent device meshes and slicing them,
        # get_root_mesh should return the correct parent mesh for each submesh
        mesh1 = init_device_mesh(
            self.device_type,
            (2, 2),
            mesh_dim_names=("dp", "tp"),
        )
        mesh1_dp = mesh1["dp"]
        mesh1_tp = mesh1["tp"]

        mesh2 = init_device_mesh(
            self.device_type,
            (2, 2),
            mesh_dim_names=("dim1", "dim2"),
        )
        mesh2_dim1 = mesh2["dim1"]
        mesh2_dim2 = mesh2["dim2"]

        self.assertEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh1)
        self.assertEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh1)
        self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim1), mesh2)
        self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim2), mesh2)

        self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh2)
        self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh2)


class DeviceMeshTestNDim(DTensorTestBase):
    """DeviceMesh checks on an 8-rank world with 2-D and 3-D meshes."""

    @property
    def world_size(self):
        return 8

    @with_comms
    def test_device_mesh_nd(self):
        """Every dim group of a 2x2x2 mesh holds the ranks along that dim."""
        # construct a device mesh for self.device_type
        rank_grid = torch.arange(8).reshape(2, 2, 2)
        nd_mesh = DeviceMesh(self.device_type, rank_grid)

        # check all dim groups
        for dim, group in enumerate(nd_mesh.get_all_groups()):
            self.assertTrue(dim < rank_grid.ndim)
            # One row per group along `dim`, each row listing its two ranks.
            rank_rows = rank_grid.swapdims(-1, dim).reshape(-1, 2)

            group_size = get_world_size(group)
            self.assertIsInstance(group, ProcessGroup)
            self.assertEqual(group_size, 2)
            global_ranks = [get_global_rank(group, idx) for idx in range(group_size)]
            for row in rank_rows:
                if self.rank not in row:
                    continue
                self.assertEqual(global_ranks, row.tolist())

    @with_comms
    def test_device_mesh_hash(self):
        """Meshes built from equal tensors hash alike; a reshaped one does not."""
        grid_4x2 = torch.arange(8).reshape(4, 2)
        first = DeviceMesh(self.device_type, grid_4x2)
        second = DeviceMesh(self.device_type, grid_4x2)
        self.assertEqual(hash(first), hash(second))

        grid_2x2x2 = torch.arange(8).reshape(2, 2, 2)
        third = DeviceMesh(self.device_type, grid_2x2x2)
        self.assertNotEqual(hash(first), hash(third))
        self.assertNotEqual(hash(second), hash(third))

    @with_comms
    def test_get_local_rank_3d(self):
        """
        If we have a 3D mesh and we want to apply dp, pp, tp to it,
        mesh_dim_names = ["dp", "pp", "tp"], and the mesh tensor would be:
        mesh_3d_tensor = [
            [
                [0, 1],
                [2, 3],
            ],
            [
                [4, 5],
                [6, 7],
            ]

        ]
        """
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "pp", "tp")
        )

        # tp_rank_0: [0, 2, 4, 6], tp_rank_1: [1, 3, 5, 7]
        self.assertEqual(mesh_3d.get_local_rank("tp"), self.rank % 2)

        # pp_rank_0: [0, 1, 4, 5], pp_rank_1: [2, 3, 6, 7]
        self.assertEqual(mesh_3d.get_local_rank("pp"), (self.rank % 4) // 2)

        # dp_rank_0: [0, 1, 2, 3], dp_rank_1: [4, 5, 6, 7]
        self.assertEqual(mesh_3d.get_local_rank("dp"), self.rank // 4)

    @with_comms
    def test_device_mesh_parent_child_hash(self):
        """A sliced submesh and a standalone mesh over the same ranks differ,
        while two standalone meshes over the same ranks are equal."""
        half = self.world_size // 2
        in_first_half = self.rank < half

        mesh_2d = init_device_mesh(
            self.device_type, (2, half), mesh_dim_names=("DP", "TP")
        )

        first_half_ranks = torch.arange(0, half)
        second_half_ranks = torch.arange(half, self.world_size)
        ep_mesh_1 = DeviceMesh(self.device_type, first_half_ranks)
        ep_mesh_2 = DeviceMesh(self.device_type, second_half_ranks)
        ep_mesh = ep_mesh_1 if in_first_half else ep_mesh_2

        # ep_mesh is considered different from mesh_2d["TP"]
        tp_mesh = mesh_2d["TP"]
        self.assertEqual(tp_mesh.mesh.flatten().tolist(), ep_mesh.mesh.flatten().tolist())
        self.assertEqual(tp_mesh._layout, ep_mesh._layout)
        self.assertEqual(tp_mesh.mesh.shape, ep_mesh.mesh.shape)
        self.assertEqual(tp_mesh.device_type, ep_mesh.device_type)
        self.assertNotEqual(tp_mesh.mesh_dim_names, ep_mesh.mesh_dim_names)
        self.assertEqual(tp_mesh._thread_id, ep_mesh._thread_id)
        self.assertNotEqual(hash(tp_mesh), hash(ep_mesh))
        self.assertNotEqual(tp_mesh, ep_mesh)

        another_mesh_1 = DeviceMesh(self.device_type, first_half_ranks)
        another_mesh_2 = DeviceMesh(self.device_type, second_half_ranks)
        another_mesh = another_mesh_1 if in_first_half else another_mesh_2

        # another_mesh is considered the same as ep_mesh
        self.assertEqual(ep_mesh._flatten_rank_map, another_mesh._flatten_rank_map)
        for attr in ("_layout", "device_type", "mesh_dim_names", "_thread_id"):
            if attr == "device_type":
                # Keep the original comparison order: layout, shape, device type.
                self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
            self.assertEqual(getattr(ep_mesh, attr), getattr(another_mesh, attr))
        self.assertEqual(hash(ep_mesh), hash(another_mesh))
        self.assertEqual(ep_mesh, another_mesh)

    @with_comms
    def test_from_group_with_mesh_shape_3d(self):
        """Build a 2D mesh with ``from_group`` out of two groups of a 3D mesh."""
        # Consider the following 3D scenario and we need to create the 2D HSDP mesh from it.
        # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh
        ref_mesh = init_device_mesh(
            self.device_type,
            (2, 2, 2),
            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
        )

        shard_pg = ref_mesh["dp_shard"].get_group()
        replicate_pg = ref_mesh["dp_replicate"].get_group()

        # Take the 2D slice of the rank grid at this rank's "tp" coordinate.
        tp_coord = ref_mesh.get_local_rank(mesh_dim="tp")
        dp_mesh = DeviceMesh.from_group(
            [replicate_pg, shard_pg],
            self.device_type,
            mesh=ref_mesh.mesh[:, :, tp_coord],
            mesh_dim_names=("dp_replicate", "dp_shard"),
        )

        self.assertEqual(ref_mesh._dim_group_names[:2], dp_mesh._dim_group_names[:2])
        # Cannot check directly for mesh equality since parent meshes are not
        # the same since the ref's parent mesh is 3D
        for dim_name in ("dp_replicate", "dp_shard"):
            self.assertEqual(dp_mesh[dim_name].mesh, ref_mesh[dim_name].mesh)
            self.assertEqual(
                dp_mesh[dim_name]._dim_group_names,
                ref_mesh[dim_name]._dim_group_names,
            )

    @with_comms()
    def test_from_group_with_mesh_shape_2d(self):
        """Build a 2D mesh with ``from_group`` out of hand-made process groups."""
        # Consider the following scenario where the process group has been created,
        # but we need to create the 2D HSDP mesh from it later in the program.
        dim_names = ("dp_replicate", "dp_shard")
        ref_mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=dim_names)

        # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
        # and assign the correct shard group to each rank
        half = self.world_size // 2
        shard_rank_lists = (
            list(range(half)),
            list(range(half, self.world_size)),
        )
        # Groups are created in the same order on every rank.
        shard_groups = tuple(new_group(ranks) for ranks in shard_rank_lists)
        current_shard_group = shard_groups[0 if self.rank in shard_rank_lists[0] else 1]

        # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7))
        # and assign the correct replicate group to each rank
        current_replicate_group = None
        shard_factor = len(shard_rank_lists[0])
        for start in range(half):
            member_ranks = list(range(start, self.world_size, shard_factor))
            candidate_group = new_group(member_ranks)
            if self.rank in member_ranks:
                current_replicate_group = candidate_group

        dp_mesh = DeviceMesh.from_group(
            [not_none(current_replicate_group), current_shard_group],
            self.device_type,
            mesh=ref_mesh.mesh,
            mesh_dim_names=("dp_replicate", "dp_shard"),
        )

        group_pairs = zip(dp_mesh.get_all_groups(), ref_mesh.get_all_groups())
        for built_group, ref_group in group_pairs:
            built_ranks = dist.get_process_group_ranks(built_group)
            ref_ranks = dist.get_process_group_ranks(ref_group)
            self.assertEqual(built_ranks, ref_ranks)
        # check both the 2d mesh and the submeshes are exactly the same.
        self.assertEqual(dp_mesh, ref_mesh)
        self.assertEqual(dp_mesh["dp_replicate"], ref_mesh["dp_replicate"])
        self.assertEqual(dp_mesh["dp_shard"], ref_mesh["dp_shard"])


class InitDeviceMeshTest(DTensorTestBase):
    """Checks for ``init_device_mesh``: equivalence with ``DeviceMesh``,
    argument validation, and the ``backend_override`` argument."""

    @property
    def world_size(self):
        return 8

    def _get_backend_options(
        self, mesh: DeviceMesh, dim_idx: int
    ) -> "C10dBackend.Options | None":
        """Return the ``options`` of the backend behind one mesh dim group.

        Args:
            mesh: mesh whose process group is inspected.
            dim_idx: index of the mesh dimension whose group is looked up.

        Returns:
            The backend's ``options`` attribute for this rank's device. It can
            be ``None``; the tests below assert exactly that for groups that
            were created without explicit options.
        """
        device = torch.device(f"{self.device_type}:{self.rank}")
        return mesh.get_group(dim_idx)._get_backend(device).options

    @with_comms
    def test_init_device_mesh(self):
        """``init_device_mesh`` matches a DeviceMesh built from ``arange``."""
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        ref_mesh = DeviceMesh(
            self.device_type,
            torch.arange(8).view(mesh_shape),
            mesh_dim_names=mesh_dim_names,
        )

        # test init_device_mesh with mesh_dim_names
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

    @with_comms
    def test_raises_duplicate_mesh_dim_names(self):
        """Repeated dimension names are rejected."""
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        """A name list longer than the mesh shape is rejected."""
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )

    def _test_backend_override_argument_dict_with_idx_and_backend(self):
        """Shared body: override backends by dim index, with and without options."""
        opts = FakeProcessGroup.Options()
        opts.fake_option = 42

        mesh = init_device_mesh(
            self.device_type,
            (2, 2, 2),
            mesh_dim_names=("dp", "tp", "cp"),
            backend_override={0: "fake", 2: ("fake", opts)},
        )

        # Fake pg only have BackendType as BackendType::CUSTOM.
        self.assertEqual(mesh.get_group(0)._get_backend_name(), "custom")
        self.assertNotEqual(mesh.get_group(1)._get_backend_name(), "custom")
        self.assertEqual(mesh.get_group(2)._get_backend_name(), "custom")

        self.assertIsNone(self._get_backend_options(mesh, 0))
        self.assertEqual(self._get_backend_options(mesh, 2).fake_option, 42)

        dp_tp_mesh = mesh["dp", "tp"]._flatten()
        dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override="fake")
        tp_cp_mesh = mesh["tp", "cp"]._flatten(backend_override=("fake", opts))

        self.assertNotEqual(dp_tp_mesh.get_group(0)._get_backend_name(), "custom")
        self.assertEqual(dp_cp_mesh.get_group(0)._get_backend_name(), "custom")
        self.assertEqual(tp_cp_mesh.get_group(0)._get_backend_name(), "custom")

        self.assertIsNone(self._get_backend_options(dp_cp_mesh, 0))
        self.assertEqual(self._get_backend_options(tp_cp_mesh, 0).fake_option, 42)

    @with_comms
    def test_backend_override_argument_dict_with_idx_and_backend_lazy(self):
        self._test_backend_override_argument_dict_with_idx_and_backend()

    @with_comms(eager_init=True)
    def test_backend_override_argument_dict_with_idx_and_backend_eager(self):
        self._test_backend_override_argument_dict_with_idx_and_backend()

    @with_comms(backend="fake")
    def test_backend_override_argument_dict_with_name_and_options(self):
        """Override by dim name, passing only an options object."""
        opts = FakeProcessGroup.Options()
        opts.fake_option = 42

        mesh = init_device_mesh(
            self.device_type,
            (2, 2, 2),
            mesh_dim_names=("dp", "tp", "cp"),
            backend_override={"tp": opts},
        )

        self.assertIsNone(self._get_backend_options(mesh, 0))
        self.assertEqual(self._get_backend_options(mesh, 1).fake_option, 42)
        self.assertIsNone(self._get_backend_options(mesh, 2))

        dp_tp_mesh = mesh["dp", "tp"]._flatten()
        dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override=opts)

        self.assertIsNone(self._get_backend_options(dp_tp_mesh, 0))
        self.assertEqual(self._get_backend_options(dp_cp_mesh, 0).fake_option, 42)

    @with_comms
    def test_backend_override_argument_errors(self):
        """Redundant or unknown ``backend_override`` keys are rejected."""
        with self.assertRaisesRegex(
            RuntimeError,
            "Found redundant dim index 0 and name dp in backend_override",
        ):
            init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=("dp", "tp"),
                backend_override={"dp": "foo", 0: "bar"},
            )

        with self.assertRaisesRegex(
            RuntimeError,
            r"Found invalid keys in backend_override: got \['cp'\]",
        ):
            init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=("dp", "tp"),
                backend_override={"cp": "foo"},
            )

        with self.assertRaisesRegex(
            RuntimeError,
            r"Found invalid keys in backend_override: got \[42\]",
        ):
            init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=("dp", "tp"),
                backend_override={42: "bar"},
            )


class TestDeviceMeshGetItem(DTensorTestBase):
    """Slicing, flattening, unflattening and concatenation tests for ``DeviceMesh``.

    All tests in this class run with a world size of 8.
    """

    @property
    def world_size(self):
        """Number of ranks every test in this class is run with."""
        return 8

    @with_comms
    def test_raises_no_mesh_dim_found(self):
        """Slicing by name a mesh built without ``mesh_dim_names`` raises."""
        mesh = init_device_mesh(self.device_type, (2, 4))
        # Only the slice itself is expected to raise, so the mesh is built
        # outside of the assertion context.
        with self.assertRaisesRegex(
            RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!"
        ):
            mesh["DP"]

    @with_comms
    def test_raises_invalid_mesh_dim_name(self):
        """Slicing with a name that is not a mesh dim name raises ``KeyError``."""
        child_mesh_dim_name = ("PP",)
        mesh_dim_names = ("DP", "TP")
        mesh = init_device_mesh(
            self.device_type,
            (2, 4),
            mesh_dim_names=mesh_dim_names,
        )
        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            mesh[child_mesh_dim_name]

    @with_comms
    def test_get_item_2d(self):
        """1D slices of a ``(2, 4)`` mesh hold the ranks of this rank's group."""
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # For every mesh dim, build a 2D tensor whose rows are the rank groups
        # along that dim.
        pg_ranks_by_dim_name = {}
        for mesh_dim, mesh_dim_name in enumerate(mesh_dim_names):
            pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
                -1, mesh_dim
            ).reshape(-1, mesh_2d.mesh.size(mesh_dim))

        tp_mesh = mesh_2d["TP"]
        tp_group_idx = self.rank // 4
        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])

        dp_group_idx = self.rank % 4
        self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])

    @with_comms
    def test_get_item_1d(self):
        """A 1D mesh can be sliced by its own name but not by an unknown one."""
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing out 1D mesh from a 1D mesh works.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
            mesh["dim0"]

    @with_comms
    def test_get_item_3d(self):
        """1D and 2D slices of a ``(2, 2, 2)`` mesh contain the expected ranks."""
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("Replicate", "Shard", "TP")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]]
        tp_group_idx = self.rank // 2
        self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx])

        shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]]
        shard_group_idx = self.rank % 2 + self.rank // 4 * 2
        self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx])

        replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
        replicate_group_idx = self.rank % 4
        self.assertEqual(
            mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx]
        )

        # We support both UX for nD slicing.
        # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"]
        hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]]
        hsdp_mesh_2 = mesh_3d["Replicate", "Shard"]
        hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
        hsdp_group_idx = self.rank % 2
        self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_1, hsdp_mesh_2)

        # Test slicing out 1D mesh from a sub-2D mesh.
        shard_mesh = hsdp_mesh_2["Shard"]
        self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
        replicate_mesh = hsdp_mesh_2["Replicate"]
        self.assertEqual(
            replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
        )

    @with_comms
    def test_cache_and_reuse_submesh_slice_result(self):
        """Repeated slicing must not change the number of process groups."""
        mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))

        # Slice "dp" once before taking the baseline, so that the check below
        # really exercises the *second* "dp" slice.
        mesh["dp"]
        ref_pg_count = _world.group_count

        # When we call the "dp" slice second time, it should not create any new pg.
        # As we are just using the cached result so the pg count should be the same.
        mesh["dp"]
        self.assertEqual(ref_pg_count, _world.group_count)

        # When we call the "tp" slice, it should not create a new pg, as the "tp" slice would
        # just reuse the parent mesh pg.
        mesh["tp"]
        self.assertEqual(_world.group_count, ref_pg_count)

    @with_comms
    def test_get_item_3d_noncontiguous_slicing(self):
        """Non-adjacent dims can be sliced together, but only in mesh order."""
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "pp", "cp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # Slice order simply decides which mesh_dim sits on which mesh_dim.
        # For dp_cp_mesh, cp mesh is the innermost dimension.
        dp_cp_mesh = mesh_3d["dp", "cp"]
        expected_mesh_tensor = (
            torch.tensor([[0, 1], [4, 5]], dtype=torch.int)
            if self.rank in (0, 1, 4, 5)
            else torch.tensor([[2, 3], [6, 7]], dtype=torch.int)
        )
        dp_local_rank = dp_cp_mesh.get_local_rank("dp")
        self.assertEqual(dp_cp_mesh.mesh, expected_mesh_tensor)
        cp_mesh = mesh_3d["cp"]
        # Check on the current dp_local_rank, whether the cp mesh tensor is the same.
        self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh)

        # Dim names given out of mesh order ("cp" before "dp") are rejected.
        with self.assertRaisesRegex(
            KeyError,
            "Invalid mesh_dim_names",
        ):
            mesh_3d["cp", "dp"]

    @with_comms
    def test_flatten_mesh_1d(self):
        """Flattening a 1D mesh does not raise."""
        mesh_shape = (4,)
        mesh_dim_names = ("default",)
        mesh_1d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        mesh_1d._flatten()

    @with_comms
    def test_flatten_mesh_3d(self):
        """Flatten contiguous and non-contiguous 2D slices of a 3D mesh."""
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("dp", "cp", "tp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # Test flatten into an existing mesh_dim_name inside the mesh
        with self.assertRaisesRegex(
            ValueError,
            "already exists for submesh of the DeviceMesh",
        ):
            mesh_3d._flatten("dp")

        # Test flatten contiguous dims
        dp_cp_mesh = mesh_3d["dp", "cp"]
        flattened_dp_cp_mesh = dp_cp_mesh._flatten()
        self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
        self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
        self.assertEqual(flattened_dp_cp_mesh.get_group().group_desc, "mesh_dp_cp")
        root_mesh = dp_cp_mesh._get_root_mesh()
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_layout = root_mesh._flatten_mapping["dp_cp"]._layout
        self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
        self.assertEqual(
            flattened_dp_cp_mesh._layout.global_ranks(8),
            [[0, 2, 4, 6], [1, 3, 5, 7]],
        )

        ref_pg_count = _world.group_count
        # Calling flatten again should not create a new pg.
        flattened_dp_cp_mesh_2 = dp_cp_mesh._flatten()
        self.assertEqual(flattened_dp_cp_mesh, flattened_dp_cp_mesh_2)
        self.assertEqual(ref_pg_count, _world.group_count)

        # Test flatten non-contiguous dims
        dp_tp_mesh = mesh_3d["dp", "tp"]
        flattened_dp_tp_mesh = dp_tp_mesh._flatten()
        self.assertEqual(dp_tp_mesh.mesh.flatten(), flattened_dp_tp_mesh.mesh)
        self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
        root_mesh = dp_tp_mesh._get_root_mesh()
        self.assertEqual(root_mesh, mesh_3d)
        flatten_mesh_root_layout = root_mesh._flatten_mapping["dp_tp"]._layout
        self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
        self.assertEqual(
            flattened_dp_tp_mesh._layout.global_ranks(8),
            [[0, 1, 4, 5], [2, 3, 6, 7]],
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Currently, this only allows slicing out a contiguous flattened dim",
        ):
            mesh_3d["dp_tp", "cp"]

        # Test flatten with a flattened mesh_dim_name
        cp_tp_mesh = mesh_3d["cp", "tp"]
        cp_tp_mesh._flatten("dummy")
        self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")

        # Test flatten into an existing mesh_dim_name inside the mesh
        with self.assertRaisesRegex(
            ValueError,
            "dp already exists for submesh of the DeviceMesh",
        ):
            mesh_3d._flatten("dp")
        with self.assertRaisesRegex(
            ValueError,
            "Flatten mesh with mesh_dim_name dp_tp has been created before",
        ):
            mesh_3d["cp", "tp"]._flatten("dp_tp")

    @with_comms(eager_init=True)
    def test_flatten_mesh_4d(self):
        """Flatten the first three dims of a 4D mesh into a single named dim."""
        mesh_shape = (2, 2, 2, 1)
        mesh_dim_names = ("dp_replicate", "dp_shard", "cp", "tp")
        mesh_4d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        # flatten HSDP and CP into one mesh
        dp_cp_mesh = mesh_4d[mesh_dim_names[:3]]._flatten("dp_cp")
        # check flattened mesh integrity
        self.assertEqual(mesh_4d["dp_cp"].mesh.flatten(), dp_cp_mesh.mesh)
        # check flattened mesh dim names is correct
        self.assertEqual(dp_cp_mesh.mesh_dim_names, ("dp_cp",))
        # check flattened mesh dependency
        self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)

    @with_comms
    def test_unflatten_mesh_2d(self):
        """Unflatten dim 0 of a ``(4, 2)`` mesh; the untouched dim is preserved."""
        mesh_shape = (4, 2)
        mesh_dim_names = ("dp", "tp")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
        self.assertEqual(
            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
        )
        self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
        self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())

        # Not supporting slicing out unflatten dim name from root mesh.
        with self.assertRaises(KeyError):
            mesh_2d["dp_shard"]

    @with_comms
    def test_unflatten_mesh_3d(self):
        """Unflatten a 1D world mesh in several ways, optionally with NCCL options."""
        # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
        global_mesh = init_device_mesh(
            self.device_type,
            (8,),
            mesh_dim_names=("world",),
        )
        non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
        ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
        self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
        self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
        mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
        unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
        self.assertEqual(
            unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
        )
        self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
        self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
        self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
        self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())

        # Test unflatten with backend override set.
        if not _NCCL_AVAILABLE:
            return
        opts = dist.ProcessGroupNCCL.Options()
        opts._timeout = timedelta(seconds=30)
        mesh_2d = global_mesh._unflatten(
            0,
            (1, 8),
            ("pp", "spmd"),
            backend_override={"pp": "fake", "spmd": ("nccl", opts)},
        )
        # A fresh Options object carries the different timeout used for "tp".
        opts = dist.ProcessGroupNCCL.Options()
        opts._timeout = timedelta(seconds=60)
        mesh_4d = mesh_2d._unflatten(
            1,
            (2, 2, 2),
            ("dp", "cp", "tp"),
            backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
        )
        self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
        # The work launched on each group must carry that group's timeout.
        spmd_pg = mesh_2d["spmd"].get_group()
        self.assertEqual(spmd_pg._get_backend_name(), "nccl")
        w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
        self.assertTrue(
            spmd_pg._get_backend(
                torch.device(f"cuda:{self.rank}")
            )._verify_work_timeout(w, timedelta(seconds=30))
        )
        w.wait()
        tp_pg = mesh_4d["tp"].get_group()
        self.assertEqual(tp_pg._get_backend_name(), "nccl")
        w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
        self.assertTrue(
            tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
                w, timedelta(seconds=60)
            )
        )
        w.wait()

    @with_comms
    def test_concatenate_2d(self):
        """Concatenating both 1D slices of a 2D mesh matches the original mesh."""
        mesh_shape = (2, 4)
        mesh_dim_names = ("dp", "tp")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        concatenated_mesh = DeviceMesh._concatenate([mesh_2d["dp"], mesh_2d["tp"]])
        self.assertEqual(concatenated_mesh.mesh, mesh_2d.mesh)
        self.assertEqual(concatenated_mesh.get_group("dp"), mesh_2d.get_group("dp"))
        self.assertEqual(concatenated_mesh.get_group("tp"), mesh_2d.get_group("tp"))

    @with_comms
    def test_concatenate_3d(self):
        """Concatenating slices of a 3D mesh matches the equivalent nD slice."""
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("pp", "dp", "tp")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        concatenated_mesh = DeviceMesh._concatenate([mesh_3d["dp"], mesh_3d["tp"]])
        dp_tp_mesh = mesh_3d["dp", "tp"]
        self.assertEqual(concatenated_mesh.mesh, dp_tp_mesh.mesh)
        self.assertEqual(concatenated_mesh.get_group("dp"), dp_tp_mesh.get_group("dp"))
        self.assertEqual(concatenated_mesh.get_group("tp"), dp_tp_mesh.get_group("tp"))
        self.assertEqual(
            mesh_3d, DeviceMesh._concatenate([mesh_3d["pp", "dp"], mesh_3d["tp"]])
        )

    @with_comms
    def test_reconstruct_mesh_with_flatten_dim(self):
        """A flattened dim name can be used when slicing from the root mesh."""
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("replicate", "shard", "cp")
        )
        shard_cp_mesh = mesh_3d["shard", "cp"]._flatten()
        hsdp_mesh = mesh_3d["replicate", "shard_cp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int
        )
        self.assertEqual(hsdp_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(shard_cp_mesh.get_group(), mesh_3d["shard_cp"].get_group())
        self.assertEqual(
            shard_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="shard_cp")
        )

        # Same check with the flattened dim in the leading position.
        mesh_3d = init_device_mesh(
            self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
        )
        dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()
        spmd_mesh = mesh_3d["dp_cp", "tp"]
        expected_mesh_tensor = torch.tensor(
            [[0, 1], [2, 3], [4, 5], [6, 7]], dtype=torch.int
        )
        self.assertEqual(spmd_mesh.mesh, expected_mesh_tensor)
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d["dp_cp"].get_group())
        self.assertEqual(dp_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="dp_cp"))


class TestMeshEnv(DTensorTestBase):
    """Tests for root-mesh lookups and related ``DeviceMesh`` helpers on 8 ranks."""

    @property
    def world_size(self):
        """Number of ranks every test in this class is run with."""
        return 8

    @with_comms
    def test_get_root_mesh(self):
        """Every 1D and 2D slice of a 3D mesh reports that mesh as its root."""
        root = init_device_mesh(
            self.device_type,
            (2, 2, 2),
            mesh_dim_names=("dp", "cp", "tp"),
        )

        # 2D slices first, then 1D slices.
        submeshes = [
            root["dp", "cp"],
            root["dp", "tp"],
            root["cp", "tp"],
            root["dp"],
            root["cp"],
            root["tp"],
        ]

        # Test BC case is still working
        for submesh in submeshes:
            self.assertEqual(_mesh_resources.get_root_mesh(submesh), root)

        # The method on the submesh itself must agree.
        for submesh in submeshes:
            self.assertEqual(submesh._get_root_mesh(), root)

    @with_comms
    def test_get_root_mesh_dim_exist(self):
        """A 1D slice reports the index of its dim within the root mesh."""
        dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type,
            (2, self.world_size // 2),
            mesh_dim_names=dim_names,
        )

        for expected_dim, dim_name in enumerate(dim_names):
            self.assertEqual(mesh_2d[dim_name]._get_root_mesh_dim(), expected_dim)

    @with_comms
    def test_get_root_mesh_dim_not_exist(self):
        """A mesh that was not sliced from another mesh has no root mesh dim."""
        unsliced_mesh = init_device_mesh(self.device_type, (self.world_size,))

        root_dim = unsliced_mesh._get_root_mesh_dim()
        self.assertEqual(root_dim, None)

    @with_comms
    def test_get_mesh_dim_by_name(self):
        """Dim names map back to their positional index."""
        dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type,
            (2, self.world_size // 2),
            mesh_dim_names=dim_names,
        )

        for expected_dim, dim_name in enumerate(dim_names):
            self.assertEqual(mesh_2d._get_mesh_dim_by_name(dim_name), expected_dim)

    @with_comms
    def test_get_all_submeshes(self):
        """A ``(2, 4)`` mesh yields four 2-rank submeshes for its first dim."""
        mesh_2d = init_device_mesh(
            self.device_type,
            (2, 4),
            mesh_dim_names=("replicate", "shard"),
        )

        replicate_submeshes = mesh_2d._get_all_submeshes("replicate")
        self.assertEqual(len(replicate_submeshes), 4)

        each_has_two_ranks = all(
            submesh.mesh.numel() == 2 for submesh in replicate_submeshes
        )
        self.assertEqual(each_has_two_ranks, True)

    @with_comms
    def test_mesh_slice_fake_tensor_mode(self):
        """Slicing a mesh inside ``FakeTensorMode`` does not raise."""
        mesh_2d = init_device_mesh(
            self.device_type,
            (2, self.world_size // 2),
            mesh_dim_names=("DP", "TP"),
        )

        with FakeTensorMode():
            # A tuple key is what ``mesh_2d["DP", "TP"]`` passes to __getitem__.
            for key in ("DP", "TP", ("DP", "TP")):
                mesh_2d[key]


class DeviceMeshCollectiveTest(DTensorTestBase):
    """Collective-communication tests over 1D and 3D device meshes on 8 ranks."""

    @property
    def world_size(self):
        """Number of ranks every test in this class is run with."""
        return 8

    @with_comms
    def test_broadcast_1d(self):
        """After a broadcast on a 1D mesh every rank holds an all-zero tensor."""
        flat_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # Each rank starts from a tensor filled with its own rank id, so only
        # rank 0 starts with zeros.
        payload = torch.ones(3, 3, device=self.device_type) * self.rank
        mesh_broadcast(payload, flat_mesh, mesh_dim=0)
        self.assertEqual(payload, torch.zeros(3, 3))

    @with_comms
    def test_scatter_1d(self):
        """Scatter evenly split shards along each tensor dim of a 1D mesh."""
        flat_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # The shape grows cumulatively: each iteration multiplies one more dim
        # by the world size and the enlarged dims stay for later iterations.
        shape = [3, 3, 3]
        for dim in range(len(shape)):
            shape[dim] *= self.world_size
            # make the random seed same across rank
            torch.manual_seed(0)
            full_tensor = torch.randn(shape, device=self.device_type)
            shards, _ = Shard(dim)._split_tensor(
                full_tensor, flat_mesh.size(), with_padding=True, contiguous=True
            )
            my_shard = shards[flat_mesh.get_rank()]
            received = torch.empty_like(my_shard)
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_scatter(received, shards, flat_mesh, mesh_dim=0)
            self.assertEqual(received, my_shard)

    @with_comms
    def test_scatter_uneven(self):
        """Scatter padded shards of a tensor that does not divide evenly."""
        flat_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        rank = flat_mesh.get_rank()
        num_ranks = flat_mesh.size()
        source = torch.randn(num_ranks + 3, num_ranks + 1, device=self.device_type)

        for shard_dim in range(source.ndim):
            # Reference result: plain torch.chunk, topped up with empty tensors
            # when it yields fewer chunks than there are ranks.
            expected_chunks = list(
                torch.chunk(source, self.world_size, dim=shard_dim)
            )
            missing = self.world_size - len(expected_chunks)
            expected_chunks.extend(
                torch.tensor([], device=self.device_type) for _ in range(missing)
            )

            padded_shards, pad_sizes = Shard(shard_dim)._split_tensor(
                source.clone(),
                num_ranks,
                with_padding=True,
                contiguous=True,
            )

            received = torch.empty_like(padded_shards[rank])
            mesh_scatter(received, padded_shards, flat_mesh, mesh_dim=0)

            if pad_sizes[rank] != 0:
                received = unpad_tensor(received, shard_dim, pad_sizes[rank])

            expected = expected_chunks[rank]
            if received.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(received.numel(), expected.numel())
            else:
                self.assertEqual(received.size(), expected.size())
                self.assertEqual(received, expected)

    @with_comms
    def test_all_gather_uneven(self):
        """All-gather padded shards and recover the original uneven tensor."""
        flat_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        rank = flat_mesh.get_rank()
        num_ranks = flat_mesh.size()
        full_tensor = torch.ones(
            num_ranks + 3,
            num_ranks + 1,
            device=self.device_type,
        )

        for shard_dim in range(full_tensor.ndim):
            padded_shards, pad_sizes = Shard(shard_dim)._split_tensor(
                full_tensor,
                num_ranks,
                with_padding=True,
                contiguous=True,
            )
            gathered = funcol.all_gather_tensor(
                padded_shards[rank], gather_dim=shard_dim, group=(flat_mesh, 0)
            )

            # Cut the gathered tensor back into per-rank pieces and strip the
            # padding of each piece before stitching them together again.
            pieces = []
            gathered_chunks = torch.chunk(gathered, num_ranks, dim=shard_dim)
            for pad, piece in zip(pad_sizes, gathered_chunks):
                if pad > 0:
                    piece = unpad_tensor(piece, shard_dim, pad)
                pieces.append(piece)
            reassembled = torch.cat(pieces, dim=shard_dim)

            self.assertEqual(reassembled.size(), full_tensor.size())
            self.assertEqual(reassembled, full_tensor)

    @with_comms
    def test_reduce_scatter_contiguous(self):
        """Reduce-scatter gives equal results for strided and contiguous input."""
        flat_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        rank = flat_mesh.get_rank()

        # Init the tensor
        step = self.world_size * 2
        num_elems = step**2
        base = torch.arange(0, num_elems).view(step, -1).to(device=self.device_type)
        base = base * (rank + 1)

        # Get non-contiguous tensor by slicing
        strided_local = base[::2, :2]
        contiguous_local = strided_local.clone().contiguous()

        # Partial to Shard to trigger reduce_scatter
        strided_dtensor = DTensor.from_local(strided_local, flat_mesh, [_Partial()])
        contiguous_dtensor = DTensor.from_local(
            contiguous_local, flat_mesh, [_Partial()]
        )
        strided_sharded = strided_dtensor.redistribute(flat_mesh, [Shard(0)])
        contiguous_sharded = contiguous_dtensor.redistribute(flat_mesh, [Shard(0)])

        # The output for contiguous and non-contiguous tensors of the same value
        # should return the same reducescatter value.
        strided_out = strided_sharded._local_tensor
        contiguous_out = contiguous_sharded._local_tensor
        self.assertEqual(strided_out, contiguous_out)
        self.assertEqual(list(strided_out.size()), [1, 2])

        # Check the reduce numerical value
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = rank * sum_base * step * 2
        expected = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=strided_out.dtype,
            device=self.device_type,
        )
        self.assertEqual(strided_out, expected)

    @with_comms
    def test_reduce_scatter_uneven(self):
        """Reduce-scatter padded shards of a tensor that does not divide evenly."""
        flat_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        rank = flat_mesh.get_rank()
        num_ranks = flat_mesh.size()
        # Every element on a rank equals that rank's id.
        source = (
            torch.ones(
                num_ranks + 3,
                num_ranks + 1,
                device=self.device_type,
            )
            * self.rank
        )
        # Value every reduced element is compared against: 0 + 1 + ... + (world_size - 1).
        res_num = ((0 + self.world_size - 1) * self.world_size) / 2

        for shard_dim in range(source.ndim):
            # Reference shapes: plain torch.chunk, topped up with empty tensors
            # when it yields fewer chunks than there are ranks.
            expected_chunks = list(
                torch.chunk(source, self.world_size, dim=shard_dim)
            )
            missing = self.world_size - len(expected_chunks)
            expected_chunks.extend(
                torch.tensor([], device=self.device_type) for _ in range(missing)
            )

            padded_shards, pad_sizes = Shard(shard_dim)._split_tensor(
                source.clone(),
                num_ranks,
                with_padding=True,
                contiguous=True,
            )

            reduce_input = torch.cat(padded_shards, shard_dim)
            reduced = funcol.reduce_scatter_tensor(
                reduce_input,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(flat_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[rank] > 0:
                reduced = unpad_tensor(reduced, shard_dim, pad_sizes[rank])

            expected_chunk = expected_chunks[rank]
            if reduced.numel() == 0:
                # We need to check numel() instead of size if a tensor is ([]) after unpadding,
                # since the size could be ([0, 8]) after unpadding.
                self.assertEqual(reduced.numel(), expected_chunk.numel())
            else:
                self.assertEqual(reduced.size(), expected_chunk.size())
                self.assertEqual(
                    reduced,
                    torch.ones_like(expected_chunk) * res_num,
                )

    @with_comms
    def test_broadcast_nd(self):
        """Broadcast along every dim of a ``(2, 2, 2)`` mesh."""
        cube_mesh = DeviceMesh(self.device_type, torch.arange(8).reshape(2, 2, 2))
        rank_filled = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups
        for dim, dim_group in enumerate(cube_mesh.get_all_groups()):
            group_size = get_world_size(dim_group)
            group_global_ranks = [
                get_global_rank(dim_group, group_rank)
                for group_rank in range(group_size)
            ]
            # Broadcast a copy so that every dim starts from the rank-filled input.
            payload = rank_filled.clone()
            mesh_broadcast(payload, cube_mesh, mesh_dim=dim)
            # The result is expected to carry the value of the group's first rank.
            first_rank_in_group = group_global_ranks[0]
            self.assertEqual(payload, torch.ones(3, 3) * first_rank_in_group)

    @with_comms
    def test_scatter_nd(self):
        """Scatter along every dim of a ``(2, 2, 2)`` mesh."""
        cube_mesh = DeviceMesh(self.device_type, torch.arange(8).reshape(2, 2, 2))

        # check all dim groups
        for dim, dim_group in enumerate(cube_mesh.get_all_groups()):
            group_size = get_world_size(dim_group)
            group_global_ranks = [
                get_global_rank(dim_group, group_rank)
                for group_rank in range(group_size)
            ]
            # One tensor per group member, filled with that member's global rank.
            to_scatter = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in group_global_ranks
            ]
            my_coordinate = cube_mesh.get_coordinate()[dim]
            received = torch.empty_like(to_scatter[my_coordinate])
            mesh_scatter(received, to_scatter, cube_mesh, mesh_dim=dim)
            self.assertEqual(received, torch.ones(3, 3) * self.rank)


class CuTeLayoutTest(TestCase):
    def test_coalesce(self):
        """``coalesce`` collapses these layouts into one (size, stride) pair."""
        # ((3,2),(2,1)) -> (6,1)
        two_mode = _Layout((3, 2), (2, 1)).coalesce()
        self.assertEqual(list(two_mode.sizes_and_strides), [(6, 1)])

        # ((2,12),(3,4),(4,1)) -> (24,1)
        three_mode = _Layout((2, 3, 4), (12, 4, 1)).coalesce()
        self.assertEqual(list(three_mode.sizes_and_strides), [(24, 1)])

    def test_coalesce_non_coalescible(self):
        """``coalesce`` returns the same modes when they cannot be merged."""
        # ((3,4),(2,1)) stays as-is (4 ≠ 2*1)
        gapped = _Layout((3, 2), (4, 1))
        self.assertEqual(
            list(gapped.coalesce().sizes_and_strides),
            [(3, 4), (2, 1)],
        )

    def test_complement_n_group_layout(self):
        """Check ``complement``, ``all_ranks_from_zero`` and ``global_ranks``.

        Each scenario builds a process-group layout, takes its complement for
        a given world size and compares the resulting sizes/strides and rank
        enumerations with hand-written expectations.
        """
        # Expectations shared by the layouts that enumerate ranks 0, 2, 4, 6.
        even_ranks = [0, 2, 4, 6]
        even_odd_groups_8 = [
            [0, 2, 4, 6],
            [1, 3, 5, 7],
        ]
        even_odd_groups_16 = [
            [0, 2, 4, 6],
            [1, 3, 5, 7],
            [8, 10, 12, 14],
            [9, 11, 13, 15],
        ]

        # complement((4,2), 8) = (2,1); together form (8,1)
        strided = _Layout((4,), (2,))
        strided_rest = strided.complement(world_size=8)
        self.assertEqual(list(strided_rest.sizes_and_strides), [(2, 1)])
        self.assertEqual(strided.all_ranks_from_zero(), even_ranks)
        # Shifting the layout's ranks by each complement offset gives the groups.
        shifted = []
        for offset in strided_rest.all_ranks_from_zero():
            shifted.append(
                [offset + rank for rank in strided.all_ranks_from_zero()]
            )
        self.assertEqual(shifted, even_odd_groups_8)
        self.assertEqual(strided.global_ranks(8), even_odd_groups_8)

        # complement((4,2), 16) = ((2,8), (2,1)); together form (16,1)
        strided_rest = strided.complement(world_size=16)
        self.assertEqual(list(strided_rest.sizes_and_strides), [(2, 8), (2, 1)])
        self.assertEqual(strided_rest.all_ranks_from_zero(), [0, 1, 8, 9])
        self.assertEqual(strided.global_ranks(16), even_odd_groups_16)

        # Complement ((2,4), (2,1)) under world_size=16 → complement ((2,8), (2,2))
        paired = _Layout((2, 2), (4, 1))
        self.assertEqual(paired.all_ranks_from_zero(), [0, 1, 4, 5])
        paired_rest = paired.complement(world_size=16)
        self.assertEqual(list(paired_rest.sizes_and_strides), [(2, 8), (2, 2)])
        self.assertEqual(paired_rest.all_ranks_from_zero(), [0, 2, 8, 10])
        self.assertEqual(
            paired.global_ranks(16),
            [
                [0, 1, 4, 5],
                [2, 3, 6, 7],
                [8, 9, 12, 13],
                [10, 11, 14, 15],
            ],
        )

        # Test global_ranks and all_ranks_from_zero on a two-mode layout.
        two_mode = _Layout((2, 2), (4, 2))
        self.assertEqual(two_mode.all_ranks_from_zero(), even_ranks)
        self.assertEqual(two_mode.global_ranks(16), even_odd_groups_16)
        two_mode_rest = two_mode.complement(world_size=16)
        self.assertEqual(list(two_mode_rest.sizes_and_strides), [(2, 8), (2, 1)])

        # When the strides are not monotonically decreasing, the complement is
        # the same as for the layout with its strides sorted; only the order of
        # ranks inside each group differs.
        swapped = _Layout((2, 2), (2, 4))
        swapped_rest = swapped.complement(world_size=16)
        self.assertEqual(list(swapped_rest.sizes_and_strides), [(2, 8), (2, 1)])
        self.assertEqual(
            swapped.global_ranks(16),
            [
                [0, 4, 2, 6],
                [1, 5, 3, 7],
                [8, 12, 10, 14],
                [9, 13, 11, 15],
            ],
        )

        # Test just all_ranks_from_zero and global_ranks.
        strided_again = _Layout((4,), (2,))
        self.assertEqual(strided_again.all_ranks_from_zero(), even_ranks)
        self.assertEqual(strided_again.global_ranks(16), even_odd_groups_16)

    def test_composition(self):
        """Check ``composition`` on valid operands and on one invalid operand.

        The valid cases compare the composed ``sizes_and_strides`` and the rank
        groups returned by ``global_ranks(8)``; the invalid case expects
        ``composition`` to raise ``AssertionError``.
        """
        # self = ((4,2), (2,1)), l = (2,1)  → self o l = (2,1)
        orig_l = _Layout((4, 2), (2, 1))
        right_l = _Layout((2,), (1,))
        composed_layout = orig_l.composition(right_l)
        self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 1)])
        self.assertEqual(
            composed_layout.global_ranks(8),
            [
                [0, 1],
                [2, 3],
                [4, 5],
                [6, 7],
            ],
        )

        # self = (4,2), l = (2,1)  → self o l = (2,2)
        orig_l = _Layout((4,), (2,))
        right_l = _Layout((2,), (1,))
        composed_layout = orig_l.composition(right_l)
        self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 2)])
        self.assertEqual(
            composed_layout.global_ranks(8),
            [
                [0, 2],
                [1, 3],
                [4, 6],
                [5, 7],
            ],
        )

        # self = (4,2), l = ((2,2), (2,1))  → self o l = ((2,4), (2,2))
        # This is to mimic the un-flatten from a 1D mesh to a 2D mesh: the
        # one-mode layout becomes a two-mode layout.
        right_l = _Layout((2, 2), (2, 1))
        composed_layout = orig_l.composition(right_l)
        self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 4), (2, 2)])
        # Each mode of the composed layout is checked on its own.
        self.assertEqual(
            composed_layout[0].global_ranks(8),
            [
                [0, 4],
                [1, 5],
                [2, 6],
                [3, 7],
            ],
        )
        self.assertEqual(
            composed_layout[1].global_ranks(8),
            [
                [0, 2],
                [1, 3],
                [4, 6],
                [5, 7],
            ],
        )

        # Error case.
        orig_l = _Layout((4, 2), (4, 1))
        # Build the operand outside the assertRaises block so the test cannot
        # pass merely because constructing the layout raised AssertionError;
        # only the composition call itself may satisfy the expectation.
        right_l = _Layout((2,), (3,))
        with self.assertRaises(AssertionError):
            orig_l.composition(right_l)

    def test_check_non_overlap(self):
        """Run ``check_non_overlap`` on overlapping and non-overlapping layouts."""
        # Each entry is (sizes, strides, expected verdict); evaluated in order.
        cases = [
            # Valid: stride 6 > span 3, so no overlap.
            ((2, 3), (6, 1), True),
            # Invalid: stride 2 < span 3 causes overlap.
            ((2, 3), (2, 1), False),
            # Invalid: duplicate strides cause overlap.
            ((2, 3), (1, 1), False),
            # Valid: single dimension.
            ((4,), (1,), True),
            # Valid: exact boundary case, stride 3 == span 3.
            ((2, 3), (3, 1), True),
            # Valid: multi-dimensional with proper spacing.
            ((2, 2, 2), (8, 4, 1), True),
            # Valid: strides not ordered.
            ((2, 2, 2), (4, 1, 2), True),
            # Valid: interleaved but no overlap.
            ((3, 2), (2, 3), True),
        ]
        for sizes, strides, non_overlapping in cases:
            verdict = _Layout(sizes, strides).check_non_overlap()
            if non_overlapping:
                self.assertTrue(verdict)
            else:
                self.assertFalse(verdict)

    def test_remap_to_tensor(self):
        """Compare ``remap_to_tensor`` output with hand-written rank tensors."""
        # Each entry is (mesh ranks, layout sizes, layout strides, expected
        # remapped ranks); entries are evaluated in order.
        cases = [
            # Consecutive ranks with a row-major 2x2 layout.
            ([0, 1, 2, 3], (2, 2), (2, 1), [[[0, 1], [2, 3]]]),
            # Non-consecutive ranks: the result holds the actual rank values.
            ([10, 20, 30, 40], (2, 2), (2, 1), [[[10, 20], [30, 40]]]),
            # 1D layout with consecutive ranks.
            ([0, 1, 2, 3], (4,), (1,), [[0, 1, 2, 3]]),
            # Row-major 2x2 layout with another set of non-consecutive ranks.
            ([5, 10, 15, 20], (2, 2), (2, 1), [[[5, 10], [15, 20]]]),
            # Column-major style strides over a mesh stored as [0, 2, 1, 3].
            ([0, 2, 1, 3], (2, 2), (1, 2), [[[0, 1], [2, 3]]]),
            # Column-major style strides with a different set of ranks.
            ([0, 2, 1, 4], (2, 2), (1, 2), [[[0, 1], [2, 4]]]),
        ]
        for mesh_ranks, sizes, strides, expected_ranks in cases:
            mesh = torch.tensor(mesh_ranks, dtype=torch.int)
            remapped = _Layout(sizes, strides).remap_to_tensor(mesh)
            self.assertEqual(remapped, torch.tensor(expected_ranks, dtype=torch.int))


# Invoke run_tests() when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
